HDU 5756 Boss Bo(主席树、标记永久化)

题意:

给定一棵 $N\le 5\times 10^4$ 个点的树,以及 $Q\le 10^5$ 次询问。
定义一个点是好点,当且仅当它自身及其所有祖先都不是坏点。
每次询问指定 $K$ 个点为坏点,查询点 $P$ 到所有好点的:
$op=1$:距离和
$op=2$:最小距离
$op=3$:最大距离
若不存在好点,则输出 $-1$。

分析:

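换句话说,如果一个点是坏点,那么它的整棵子树都是坏点,而一棵子树对应 DFS 序上的一段连续区间。
剩下的就是官方题解的做法了:以 DFS 序为下标建主席树,第 $u$ 个版本保存点 $u$ 到每个点的距离;从父亲走到儿子 $v$ 时,先对全体 $+1$,再对 $v$ 的子树区间 $-2$(标记永久化)。询问时把坏点的子树区间排序合并,在版本 $P$ 上对剩余的各段区间分别查询和/最小值/最大值。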

//
//  Created by TaoSama on 2016-08-13
//  Copyright (c) 2016 TaoSama. All rights reserved.
//
#pragma comment(linker, "/STACK:102400000,102400000")
#include <algorithm>
#include <cctype>
#include <cinttypes>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <iomanip>
#include <iostream>
#include <map>
#include <queue>
#include <set>
#include <string>
#include <vector>

using namespace std;
// Debug helpers: stream "<expr> = <value>" followed by two spaces (pr) or by endl (prln).
#define pr(x) cout << #x << " = " << x << "  "
#define prln(x) cout << #x << " = " << x << endl
// N bounds every per-node array below; INF and MOD are not referenced in the code visible here.
const int N = 5e4 + 10, INF = 0x3f3f3f3f, MOD = 1e9 + 7;

typedef long long LL;
// Large 64-bit sentinel used as the starting value of the min / max folds.
const LL LLINF = 0x3f3f3f3f3f3f3f3fLL;

// n: number of tree nodes, q: number of queries (both read in main).
int n, q;
// Adjacency lists of the undirected tree; main fills entries 1..n.
vector<int> G[N];

// dep[v]: depth assigned by dfs (parent's depth + 1; the start node's entry is never written).
// L[u] / R[u]: first / last DFS index inside u's subtree. vs[i]: node with DFS index i.
// dfsNum: running DFS index counter, reset by main before each traversal.
int dep[N], L[N], R[N], vs[N], dfsNum;
// First traversal: numbers the nodes in DFS preorder and records depths.
// After the call, the subtree of u occupies DFS indices L[u]..R[u] and
// vs[] maps each index back to its node. `fa` is the node we came from.
void dfs(int u, int fa) {
    const int enter = ++dfsNum;
    L[u] = enter;
    vs[enter] = u;
    for(size_t i = 0; i < G[u].size(); ++i) {
        const int child = G[u][i];
        if(child != fa) {
            dep[child] = dep[u] + 1;
            dfs(child, u);
        }
    }
    // dfsNum now holds the last index handed out inside this subtree.
    R[u] = dfsNum;
}

int root[N];
struct PersistentSegTree {
    static const int M = N * 2 * 20;
    int sz;
    struct Node {
        int ls, rs;
        LL addv, maxv, minv, sum;
        void add(LL v, int len) {
            addv += v;
            sum += v * len;
            maxv += v;
            minv += v;
        }
        void see() {
            pr(addv); pr(maxv); pr(minv); prln(sum);
        }
    } dat[M];
    int newNode(int rt) {
        dat[++sz] = dat[rt];
        return sz;
    }
    void init() {
        sz = 0;
        memset(&dat[0], 0, sizeof dat[0]);
    }
    void up(int rt, int len) {
        dat[rt].sum = dat[dat[rt].ls].sum + dat[dat[rt].rs].sum + dat[rt].addv * len;
        dat[rt].minv = min(dat[dat[rt].ls].minv, dat[dat[rt].rs].minv) + dat[rt].addv;
        dat[rt].maxv = max(dat[dat[rt].ls].maxv, dat[dat[rt].rs].maxv) + dat[rt].addv;
    }
    void build(int l, int r, int& rt) {
        rt = newNode(0);
        if(l == r) {
            dat[rt].add(dep[vs[l]], 1);
            return;
        }
        int m = l + r >> 1;
        build(l, m, dat[rt].ls);
        build(m + 1, r, dat[rt].rs);
        up(rt, r - l + 1);
    }
    void update(int L, int R, int v, int l, int r, int& rt) {
        rt = newNode(rt);
        if(L <= l && r <= R) {
            dat[rt].add(v, r - l + 1);
            return;
        }
        int m = l + r >> 1;
        if(L <= m) update(L, R, v, l, m, dat[rt].ls);
        if(R > m) update(L, R, v, m + 1, r, dat[rt].rs);
        up(rt, r - l + 1);
    }
    LL query1(int L, int R, LL z, int l, int r, int rt) {
        if(L <= l && r <= R) return dat[rt].sum + z * (r - l + 1);
        int m = l + r >> 1;
        LL ret = 0;
        if(L <= m) ret += query1(L, R, z + dat[rt].addv, l, m, dat[rt].ls);
        if(R > m) ret += query1(L, R, z + dat[rt].addv, m + 1, r, dat[rt].rs);
        return ret;
    }
    LL query2(int L, int R, LL z, int l, int r, int rt) {
        if(L <= l && r <= R) return dat[rt].minv + z;
        int m = l + r >> 1;
        LL ret = LLINF;
        if(L <= m) ret = min(ret, query2(L, R, z + dat[rt].addv, l, m, dat[rt].ls));
        if(R > m) ret = min(ret, query2(L, R, z + dat[rt].addv, m + 1, r, dat[rt].rs));
        return ret;
    }
    LL query3(int L, int R, LL z, int l, int r, int rt) {
        if(L <= l && r <= R) return dat[rt].maxv + z;
        int m = l + r >> 1;
        LL ret = -LLINF;
        if(L <= m) ret = max(ret, query3(L, R, z + dat[rt].addv, l, m, dat[rt].ls));
        if(R > m) ret = max(ret, query3(L, R, z + dat[rt].addv, m + 1, r, dat[rt].rs));
        return ret;
    }
} T;

// Second traversal: creates one tree version per node. Node 1 gets the
// freshly built base version; every other node u starts from its parent's
// version, adds 1 on the whole range and then -2 on u's own DFS range
// [L[u], R[u]].
void dfs2(int u, int fa) {
    if(u != 1) {
        int& cur = root[u];
        cur = root[fa];
        T.update(1, n, 1, 1, n, cur);
        T.update(L[u], R[u], -2, 1, n, cur);
    } else {
        T.init();
        T.build(1, n, root[1]);
    }
    for(size_t i = 0; i < G[u].size(); ++i) {
        const int child = G[u][i];
        if(child != fa) dfs2(child, u);
    }
}

// Entry point. For each test case: reads n, q and the n-1 tree edges, runs
// dfs / dfs2 to build one segment-tree version per node, then answers q
// queries. A query gives k, p, op followed by k bad nodes; p is decoded with
// the previous answer (`last`). The DFS ranges of the bad nodes are removed
// and the remaining gaps are queried on version root[p]:
//   op == 1 -> sum, op == 2 -> minimum, anything else -> maximum.
// Prints -1 (and resets `last` to 0) when the bad ranges cover [1, n].
// Processing stops quietly if the input ends in the middle of a test case.
int main() {
#ifdef LOCAL
    freopen("C:\\Users\\TaoSama\\Desktop\\in.txt", "r", stdin);
//  freopen("C:\\Users\\TaoSama\\Desktop\\out.txt","w",stdout);
#endif
    ios_base::sync_with_stdio(0);

    while(scanf("%d%d", &n, &q) == 2) {
        for(int i = 1; i <= n; ++i) G[i].clear();
        for(int i = 1; i < n; ++i) {
            int u, v;
            if(scanf("%d%d", &u, &v) != 2) return 0;  // truncated input
            G[u].push_back(v);
            G[v].push_back(u);
        }

        dfsNum = 0;
        dfs(1, 0);
        dfs2(1, 0);

        LL last = 0;
        for(int z = 1; z <= q; ++z) {
            int k, p, op;
            if(scanf("%d%d%d", &k, &p, &op) != 3) return 0;  // truncated input
            p = (p + last) % n + 1;

            // Collect the DFS range of every bad node.
            vector<pair<int, int> > seg;
            seg.reserve(static_cast<size_t>(max(k, 0)) + 1);
            for(int i = 0; i < k; ++i) {
                int x;
                if(scanf("%d", &x) != 1) return 0;  // truncated input
                seg.push_back({L[x], R[x]});
            }
            sort(seg.begin(), seg.end());

            // Merge overlapping ranges in place, extending against the running
            // right end so the result does not rely on the ranges being nested.
            size_t kept = 0;
            for(size_t i = 0, j; i < seg.size(); i = j) {
                int r = seg[i].second;
                for(j = i + 1; j < seg.size() && seg[j].first <= r; ++j)
                    r = max(r, seg[j].second);
                seg[kept++] = {seg[i].first, r};
            }
            seg.resize(kept);
            // Sentinel range just past n so the trailing gap is handled by the same loop.
            seg.push_back({n + 1, n + 1});
            if(seg[0] == make_pair(1, n)) {
                puts("-1");
                last = 0;
                continue;
            }

            // Fold the answer over every gap [l, r] between consecutive bad ranges.
            last = (op == 1) ? 0 : (op == 2) ? LLINF : -LLINF;
            int l = 1;
            for(size_t i = 0; i < seg.size(); ++i) {
                const int r = seg[i].first - 1;
                if(l <= r) {
                    if(op == 1) last += T.query1(l, r, 0, 1, n, root[p]);
                    else if(op == 2) last = min(last, T.query2(l, r, 0, 1, n, root[p]));
                    else last = max(last, T.query3(l, r, 0, 1, n, root[p]));
                }
                l = seg[i].second + 1;
            }
            // PRId64 expands to the correct 64-bit conversion on every platform,
            // unlike the Microsoft-specific "%I64d".
            printf("%" PRId64 "\n", static_cast<int64_t>(last));
        }
    }
    return 0;
}